{
 "cells": [
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "---\n",
    "skip_exec: true\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp vision.widgets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "from fastai.torch_basics import *\n",
    "from fastai.data.all import *\n",
    "from fastai.vision.core import *\n",
    "from fastcore.parallel import *\n",
    "from ipywidgets import HBox,VBox,widgets,Button,Checkbox,Dropdown,Layout,Box,Output,Label,FileUpload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
     "# nbdev adds the names listed in `_all_` to the exported module's `__all__`, so these `ipywidgets` names are re-exported\n",
     "_all_ = ['HBox','VBox','widgets','Button','Checkbox','Dropdown','Layout','Box','Output','Label','FileUpload']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Vision widgets\n",
    "\n",
    "> ipywidgets for images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@patch\n",
    "def __getitem__(self:Box, i): return self.children[i]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def widget(im, *args, **layout) -> Output:\n",
    "    \"Convert anything that can be `display`ed by `IPython` into a widget\"\n",
    "    o = Output(layout=merge(*args, layout))\n",
    "    with o: display(im)\n",
    "    return o"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "#### widget\n",
       "\n",
       ">      widget (im, *args, **layout)\n",
       "\n",
       "Convert anything that can be `display`ed by `IPython` into a widget"
      ],
      "text/plain": [
       "<nbdev.showdoc.BasicMarkdownRenderer>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(widget)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8c8de2bd3f2d43ff8a00713b47d13782",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HTML(value='Puppy'), Output(layout=Layout(max_width='192px'))))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
     "# Any displayable object (here a PIL thumbnail) can be wrapped by `widget` and composed with other ipywidgets\n",
     "im = Image.open('images/puppy.jpg').to_thumb(256,512)\n",
     "VBox([widgets.HTML('Puppy'),\n",
     "      widget(im, max_width=\"192px\")])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _update_children(\n",
    "    change:dict # A dictionary holding the information about the changed widget\n",
    "):\n",
    "    \"Sets a value to the `layout` attribute on widget initialization and change\"\n",
    "    for o in change['owner'].children:\n",
    "        if not o.layout.flex: o.layout.flex = '0 0 auto'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def carousel(\n",
    "    children:tuple|MutableSequence=(), # `Box` objects to display in a carousel\n",
    "    **layout\n",
    ") -> Box: # An `ipywidget`'s carousel\n",
    "    \"A horizontally scrolling carousel\"\n",
    "    def_layout = dict(overflow='scroll hidden', flex_flow='row', display='flex')\n",
    "    res = Box([], layout=merge(def_layout, layout))\n",
    "    res.observe(_update_children, names='children')\n",
    "    res.children = children\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "#### carousel\n",
       "\n",
       ">      carousel (children:Union[tuple,list]=(), **layout)\n",
       "\n",
       "A horizontally scrolling carousel\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| children | tuple \\| list | () | `Box` objects to display in a carousel |\n",
       "| layout |  |  |  |\n",
       "| **Returns** | **Box** |  | **An `ipywidget`'s carousel** |"
      ],
      "text/plain": [
       "<nbdev.showdoc.BasicMarkdownRenderer>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(carousel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dbaf8c50de1d436ebb7178a5a42673cb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Box(children=(VBox(children=(Output(layout=Layout(max_width='192px')), Button(description='click', style=Butto…"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ts = [VBox([widget(im, max_width='192px'), Button(description='click')])\n",
    "      for o in range(3)]\n",
    "\n",
    "carousel(ts, width='450px')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _open_thumb(\n",
    "    fn:Path|str, # A path of an image\n",
    "    h:int, # Thumbnail Height\n",
    "    w:int # Thumbnail Width\n",
    ") -> Image: # `PIL` image to display\n",
    "    \"Opens an image path and returns the thumbnail of the image\"\n",
    "    return Image.open(fn).to_thumb(h, w).convert('RGBA')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class ImagesCleaner:\n",
    "    \"A widget that displays all images in `fns` along with a `Dropdown`\"\n",
    "    def __init__(self,\n",
    "        opts:tuple=(), # Options for the `Dropdown` menu\n",
    "        height:int=128, # Thumbnail Height\n",
    "        width:int=256, # Thumbnail Width\n",
    "        max_n:int=30 # Max number of images to display\n",
    "    ):\n",
    "        opts = ('<Keep>', '<Delete>')+tuple(opts)\n",
    "        store_attr('opts,height,width,max_n')\n",
    "        self.widget = carousel(width='100%')\n",
    "\n",
    "    def set_fns(self,\n",
    "        fns:list # Contains a path to each image \n",
    "    ):\n",
    "        \"Sets a `thumbnail` and a `Dropdown` menu for each `VBox`\"\n",
    "        self.fns = L(fns)[:self.max_n]\n",
    "        ims = parallel(_open_thumb, self.fns, h=self.height, w=self.width, progress=False,\n",
    "                       n_workers=min(len(self.fns)//10,defaults.cpus))\n",
    "        self.widget.children = [VBox([widget(im, height=f'{self.height}px'), Dropdown(\n",
    "            options=self.opts, layout={'width': 'max-content'})]) for im in ims]\n",
    "\n",
    "    def _ipython_display_(self): display(self.widget)\n",
    "    def values(self) -> list:\n",
    "        \"Current values of `Dropdown` for each `VBox`\"\n",
    "        return L(self.widget.children).itemgot(1).attrgot('value')\n",
    "    def delete(self) -> list:\n",
    "        \"Indices of items to delete\"\n",
    "        return self.values().argwhere(eq('<Delete>'))\n",
    "    def change(self) -> list:\n",
    "        \"Tuples of the form (index of item to change, new class)\"\n",
    "        idxs = self.values().argwhere(not_(in_(['<Delete>','<Keep>'])))\n",
    "        return idxs.zipwith(self.values()[idxs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "### ImagesCleaner\n",
       "\n",
       ">      ImagesCleaner (opts:'tuple'=(), height:'int'=128, width:'int'=256,\n",
       ">                     max_n:'int'=30)\n",
       "\n",
       "A widget that displays all images in `fns` along with a `Dropdown`\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| opts | tuple | () | Options for the `Dropdown` menu |\n",
       "| height | int | 128 | Thumbnail Height |\n",
       "| width | int | 256 | Thumbnail Width |\n",
       "| max_n | int | 30 | Max number of images to display |"
      ],
      "text/plain": [
       "<nbdev.showdoc.BasicMarkdownRenderer>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(ImagesCleaner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "deb65e2b8a9342c78df55a8a31186426",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Box(children=(VBox(children=(Output(layout=Layout(height='128px')), Dropdown(layout=Layout(width='max-content'…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
     "fns = get_image_files('images')\n",
     "# 'A' and 'B' are offered in each dropdown after the built-in '<Keep>' and '<Delete>' choices\n",
     "w = ImagesCleaner(('A','B'))\n",
     "w.set_fns(fns)\n",
     "w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((#0) [], (#0) [])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Indices marked '<Delete>', and (index, new class) pairs; both empty while every dropdown is on '<Keep>'\n",
     "w.delete(),w.change()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _get_iw_info(\n",
    "    learn,\n",
    "    ds_idx:int=0 # Index in `learn.dls`\n",
    ") -> list:\n",
    "    \"For every image in `dls` `zip` it's `Path`, target and loss\"\n",
    "    dl = learn.dls[ds_idx].new(shuffle=False, drop_last=False)\n",
    "    probs,targs,preds,losses = learn.get_preds(dl=dl, with_input=False, with_loss=True, with_decoded=True)\n",
    "    targs = [dl.vocab[t] for t in targs]\n",
    "    return L([dl.dataset.items,targs,losses]).zip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@delegates(ImagesCleaner)\n",
    "class ImageClassifierCleaner(GetAttr):\n",
    "    \"A widget that provides an `ImagesCleaner` for a CNN `Learner`\"\n",
    "    def __init__(self, learn, **kwargs):\n",
    "        vocab = learn.dls.vocab\n",
    "        self.default = self.iw = ImagesCleaner(vocab, **kwargs)\n",
    "        self.dd_cats = Dropdown(options=vocab)\n",
    "        self.dd_ds   = Dropdown(options=('Train','Valid'))\n",
    "        self.iwis = _get_iw_info(learn,0),_get_iw_info(learn,1)\n",
    "        self.dd_ds.observe(self.on_change_ds, 'value')\n",
    "        self.dd_cats.observe(self.on_change_ds, 'value')\n",
    "        self.on_change_ds()\n",
    "        self.widget = VBox([self.dd_cats, self.dd_ds, self.iw.widget])\n",
    "\n",
    "    def _ipython_display_(self): display(self.widget)\n",
    "    def on_change_ds(self,change=None):\n",
    "        \"Toggle between training validation set view\"\n",
    "        info = L(o for o in self.iwis[self.dd_ds.index] if o[1]==self.dd_cats.value)\n",
    "        self.iw.set_fns(info.sorted(2, reverse=True).itemgot(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/vision/widgets.py#L108){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### ImageClassifierCleaner\n",
       "\n",
       ">      ImageClassifierCleaner (learn, opts:tuple=(), height:int=128,\n",
       ">                              width:int=256, max_n:int=30)\n",
       "\n",
       "*A widget that provides an `ImagesCleaner` for a CNN `Learner`*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| learn |  |  |  |\n",
       "| opts | tuple | () | Options for the `Dropdown` menu |\n",
       "| height | int | 128 | Thumbnail Height |\n",
       "| width | int | 256 | Thumbnail Width |\n",
       "| max_n | int | 30 | Max number of images to display |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/vision/widgets.py#L108){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### ImageClassifierCleaner\n",
       "\n",
       ">      ImageClassifierCleaner (learn, opts:tuple=(), height:int=128,\n",
       ">                              width:int=256, max_n:int=30)\n",
       "\n",
       "*A widget that provides an `ImagesCleaner` for a CNN `Learner`*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| learn |  |  |  |\n",
       "| opts | tuple | () | Options for the `Dropdown` menu |\n",
       "| height | int | 128 | Thumbnail Height |\n",
       "| width | int | 256 | Thumbnail Width |\n",
       "| max_n | int | 30 | Max number of images to display |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(ImageClassifierCleaner)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
